{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型参数的访问、初始化和共享\n",
    "\n",
    "在[“线性回归的简洁实现”](../chapter_deep-learning-basics/linear-regression-gluon.ipynb)一节中，我们通过`init`模块来初始化模型的全部参数。我们也介绍了访问模型参数的简单方法。本节将深入讲解如何访问和初始化模型参数，以及如何在多个层之间共享同一份模型参数。\n",
    "\n",
    "我们先定义一个与上一节中相同的含单隐藏层的多层感知机。我们依然使用默认方式初始化它的参数，并做一次前向计算。与之前不同的是，在这里我们使用torch.nn中的`init`模块，它包含了多种模型初始化方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "net = nn.Sequential()\n",
    "net.add_module(\"hidden\", nn.Linear(20, 256))\n",
    "net.add_module(\"activation\", nn.ReLU())\n",
    "net.add_module(\"output\", nn.Linear(256, 10))\n",
    "\n",
    "X = torch.rand(2, 20)\n",
    "Y = net(X) # 前向计算"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 访问模型参数\n",
    "\n",
    "对于使用`Sequential`类构造的神经网络，我们可以通过方括号`[]`来访问网络的任一层。回忆一下上一节中提到的`Sequential`类与`Module`类的继承关系。对于`Sequential`实例中含模型参数的层，我们可以通过`Module`类的`parameters()`或者`named_parameters()`函数来访问该层包含的所有参数。下面，访问多层感知机`net`中隐藏层的所有参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hidden.weight Parameter containing:\n",
      "tensor([[-0.2020, -0.0940, -0.1824,  ..., -0.0068,  0.0970,  0.1482],\n",
      "        [ 0.0019,  0.2137,  0.1696,  ..., -0.0510, -0.1160,  0.2125],\n",
      "        [ 0.1995,  0.1376,  0.1833,  ..., -0.1808, -0.2169,  0.1999],\n",
      "        ...,\n",
      "        [-0.0865, -0.0967,  0.1390,  ...,  0.1609, -0.2124,  0.2044],\n",
      "        [ 0.0271, -0.0312, -0.0570,  ..., -0.2049,  0.0278,  0.0193],\n",
      "        [ 0.0395,  0.0134, -0.1902,  ..., -0.1404, -0.0844,  0.1393]],\n",
      "       requires_grad=True)\n",
      "hidden.bias Parameter containing:\n",
      "tensor([ 7.8103e-03,  8.7859e-02,  3.5132e-02, -1.6534e-01, -1.2234e-01,\n",
      "        -1.5377e-01, -1.3543e-01,  4.0223e-02, -2.0627e-01,  5.2887e-02,\n",
      "         1.9170e-01,  3.1859e-02,  1.7611e-01, -4.6360e-02, -6.4578e-02,\n",
      "        -2.1074e-01,  1.5730e-01,  1.6702e-01, -3.8143e-02,  1.2716e-01,\n",
      "        -8.1479e-02,  1.7518e-01, -8.7225e-02,  5.9440e-02,  3.2729e-02,\n",
      "        -1.9102e-01, -6.6407e-02,  5.4632e-02, -1.1433e-02, -1.9550e-01,\n",
      "        -2.1350e-01,  1.9242e-01,  1.3584e-01, -3.5138e-02,  1.7407e-01,\n",
      "        -1.7895e-01, -1.0502e-01,  4.1089e-02,  1.9162e-01, -9.9660e-02,\n",
      "         1.4101e-01,  1.9972e-01, -3.3590e-02,  1.5868e-01,  1.1298e-01,\n",
      "        -6.7102e-03, -3.2781e-03,  1.2014e-01,  1.6524e-02, -1.0324e-01,\n",
      "        -1.6225e-01,  2.8831e-02,  1.8810e-01, -3.9702e-02, -6.5071e-02,\n",
      "         9.2698e-02,  1.7588e-01,  2.0243e-01, -3.5606e-02,  1.6825e-01,\n",
      "        -1.0181e-01, -6.0851e-02,  1.8753e-01,  1.0442e-02,  2.0300e-01,\n",
      "         1.8198e-02,  1.7563e-01,  1.9500e-01, -4.2187e-03,  1.1838e-01,\n",
      "        -2.1104e-01,  3.0405e-02, -7.3956e-02, -1.6577e-01, -7.6479e-02,\n",
      "         1.1049e-01, -1.5387e-01,  5.7465e-02, -1.1561e-01,  1.3009e-01,\n",
      "        -9.2336e-02, -7.6848e-02, -1.4550e-01, -1.5975e-01,  7.0942e-02,\n",
      "        -2.7751e-02, -1.0816e-02, -1.1938e-02,  6.2183e-02,  1.2504e-01,\n",
      "        -1.9769e-01,  3.1202e-03, -1.0244e-01, -1.3156e-01, -2.4525e-02,\n",
      "         2.5341e-02, -1.7999e-01, -1.9250e-01, -7.8844e-02, -2.1585e-01,\n",
      "         8.6968e-02, -9.3047e-02,  7.4182e-02,  7.3537e-02,  1.3643e-01,\n",
      "        -2.1189e-01,  1.3156e-02,  2.8797e-02, -1.2593e-01,  5.9923e-02,\n",
      "         9.3257e-02, -3.5460e-02, -2.2038e-01, -1.9399e-01,  9.6660e-02,\n",
      "        -1.6550e-01, -2.0632e-01, -5.1804e-02,  1.7992e-01,  1.3033e-01,\n",
      "        -1.1642e-03, -2.2084e-01, -8.7178e-02,  8.3951e-02, -4.3986e-02,\n",
      "        -1.7319e-01, -5.1756e-03,  1.6286e-01, -2.0845e-01,  1.0643e-03,\n",
      "         7.2310e-02, -1.0791e-01,  1.5564e-01,  1.8660e-01,  1.7235e-01,\n",
      "        -1.0689e-01,  1.9457e-01,  2.1089e-01,  4.1029e-03,  6.8029e-02,\n",
      "        -1.1916e-01,  1.0269e-01,  1.4901e-01,  1.8469e-02, -1.8406e-02,\n",
      "        -2.1799e-01,  1.7398e-02,  1.6781e-01,  1.2214e-02, -1.7632e-01,\n",
      "         5.5827e-02,  5.5770e-02,  1.5198e-01, -1.0917e-02,  1.5731e-02,\n",
      "         9.9367e-02, -8.0532e-02,  1.7722e-01,  1.8200e-01, -2.0665e-01,\n",
      "         1.6061e-01, -1.4627e-01, -1.9144e-01,  7.1671e-02,  8.9991e-02,\n",
      "         9.5042e-02,  8.8171e-02,  1.4849e-01,  1.7288e-01,  2.0964e-01,\n",
      "         1.4356e-02, -1.3909e-02,  1.6731e-01, -3.1605e-02, -1.6612e-01,\n",
      "        -1.9082e-01,  1.7571e-01,  1.0642e-01,  2.7806e-02,  7.4808e-02,\n",
      "        -1.2854e-01, -5.8718e-02, -1.2209e-01, -1.0403e-01, -8.8480e-02,\n",
      "        -1.6232e-01,  2.1517e-01, -2.2300e-01,  1.1627e-01,  2.0120e-01,\n",
      "         1.4053e-01,  1.2968e-01,  2.0989e-01, -1.8834e-01, -1.4297e-01,\n",
      "        -1.1625e-01, -1.6492e-01,  5.2631e-02,  2.0453e-01,  5.8812e-02,\n",
      "         1.9767e-01, -3.8205e-02, -1.5792e-01,  2.1042e-02,  1.0014e-01,\n",
      "        -2.1404e-01, -5.8073e-02, -4.9840e-02,  1.0239e-01, -7.4449e-02,\n",
      "         7.8836e-02, -8.0697e-02, -1.8739e-01,  7.5313e-02,  1.3986e-01,\n",
      "        -2.1701e-01, -1.9326e-01,  1.4888e-01,  1.7699e-02, -7.7201e-02,\n",
      "         8.0111e-02,  1.3002e-01, -7.8334e-02, -1.1454e-01,  4.0588e-02,\n",
      "        -2.1596e-01, -1.2914e-01,  1.5755e-01, -1.0605e-01, -1.4243e-01,\n",
      "         2.5749e-02, -1.6611e-01,  4.4443e-02, -1.1710e-01,  3.8622e-02,\n",
      "        -5.2730e-02, -1.6144e-01, -1.5020e-01,  1.2373e-01,  6.5923e-02,\n",
      "         3.1857e-02,  6.1026e-02,  1.9089e-01, -1.4960e-01,  5.7682e-02,\n",
      "         1.2209e-01,  9.3395e-02, -1.5064e-04, -9.2602e-02,  4.4977e-02,\n",
      "        -1.5733e-01, -1.4177e-01, -2.2041e-01,  4.9855e-02, -1.1592e-01,\n",
      "        -2.1544e-01], requires_grad=True)\n",
      "output.weight Parameter containing:\n",
      "tensor([[-0.0005,  0.0050,  0.0176,  ...,  0.0410,  0.0416,  0.0422],\n",
      "        [-0.0046,  0.0080, -0.0603,  ...,  0.0147, -0.0106,  0.0549],\n",
      "        [-0.0488, -0.0212,  0.0199,  ...,  0.0535, -0.0559, -0.0281],\n",
      "        ...,\n",
      "        [ 0.0364,  0.0301,  0.0016,  ...,  0.0543,  0.0327, -0.0410],\n",
      "        [ 0.0018,  0.0157, -0.0035,  ..., -0.0067,  0.0570, -0.0507],\n",
      "        [-0.0410, -0.0134,  0.0572,  ...,  0.0388, -0.0371, -0.0286]],\n",
      "       requires_grad=True)\n",
      "output.bias Parameter containing:\n",
      "tensor([ 0.0053, -0.0368, -0.0521, -0.0278, -0.0111,  0.0378,  0.0265, -0.0141,\n",
      "        -0.0245,  0.0381], requires_grad=True)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "generator"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for name, param in net.named_parameters():\n",
    "    print(name, param)\n",
    "    \n",
    "type(net.named_parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到，我们得到了一个返回参数名和参数值的迭代器。其中隐藏层权重参数的名称为`hidden.weight`，它由`net[0]`的名称（`hidden`）和自己的变量名（`weight`）组成。\n",
    "\n",
    "\n",
    "**注：如果单独调用`net[0].named_parameters()`获得的权重参数名为`weight`**\n",
    "\n",
    "\n",
    "为了访问特定参数，我们可以使用`net.state_dict()`来获得一个由参数名映射到参数值的字典（类型为`OrderedDict`）。通过名字来访问字典里的元素，也可以直接使用它的变量名。下面两种方法是等价的，但通常后者的代码可读性更好。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.2020, -0.0940, -0.1824,  ..., -0.0068,  0.0970,  0.1482],\n",
       "         [ 0.0019,  0.2137,  0.1696,  ..., -0.0510, -0.1160,  0.2125],\n",
       "         [ 0.1995,  0.1376,  0.1833,  ..., -0.1808, -0.2169,  0.1999],\n",
       "         ...,\n",
       "         [-0.0865, -0.0967,  0.1390,  ...,  0.1609, -0.2124,  0.2044],\n",
       "         [ 0.0271, -0.0312, -0.0570,  ..., -0.2049,  0.0278,  0.0193],\n",
       "         [ 0.0395,  0.0134, -0.1902,  ..., -0.1404, -0.0844,  0.1393]]),\n",
       " Parameter containing:\n",
       " tensor([[-0.2020, -0.0940, -0.1824,  ..., -0.0068,  0.0970,  0.1482],\n",
       "         [ 0.0019,  0.2137,  0.1696,  ..., -0.0510, -0.1160,  0.2125],\n",
       "         [ 0.1995,  0.1376,  0.1833,  ..., -0.1808, -0.2169,  0.1999],\n",
       "         ...,\n",
       "         [-0.0865, -0.0967,  0.1390,  ...,  0.1609, -0.2124,  0.2044],\n",
       "         [ 0.0271, -0.0312, -0.0570,  ..., -0.2049,  0.0278,  0.0193],\n",
       "         [ 0.0395,  0.0134, -0.1902,  ..., -0.1404, -0.0844,  0.1393]],\n",
       "        requires_grad=True))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].state_dict()['weight'], net[0].weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "权重梯度的形状和权重的形状一样。因为我们还没有进行反向传播计算，所以梯度为None。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].weight.grad == None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "类似地，我们可以访问其他层的参数，如输出层的偏差值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "6"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([ 0.0053, -0.0368, -0.0521, -0.0278, -0.0111,  0.0378,  0.0265, -0.0141,\n",
       "        -0.0245,  0.0381], requires_grad=True)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[2].bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 初始化模型参数\n",
    "\n",
    "我们在[“数值稳定性和模型初始化”](../chapter_deep-learning-basics/numerical-stability-and-init.ipynb)一节中描述了模型的默认初始化方法：权重参数元素为[-0.07, 0.07]之间均匀分布的随机数，偏差参数则全为0。但我们经常需要使用其他方法来初始化权重。PyTorch在`nn.init`模块里提供了多种预设的初始化单个参数的方法。在下面的例子中，我们将隐藏层的权重参数初始化成均值为0、标准差为0.01的正态分布随机数，并将偏差参数初始化为常数0。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-0.0113, -0.0137,  0.0041,  0.0029, -0.0041,  0.0011, -0.0069,  0.0024,\n",
       "          0.0012, -0.0144,  0.0010, -0.0104, -0.0050,  0.0058,  0.0140, -0.0094,\n",
       "          0.0044,  0.0118,  0.0015,  0.0139]), tensor(0.))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.init.normal_(net[0].weight, mean=0, std=0.01)  # weight\n",
    "nn.init.constant_(net[0].bias, 0) # bias\n",
    "\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义初始化方法\n",
    "\n",
    "如果需要将模型中所有网络层的参数按照相同的策略进行初始化，可以自定义一个初始化方法（如`weight_init`），然后使用`.apply(weight_init)`进行自定义初始化。在下面的例子里，我们令Linear层的权重有一半概率初始化为0，有另一半概率初始化为$[-10,-5]$和$[5,10]$两个区间里均匀分布的随机数。\n",
    "\n",
    "\n",
    "**注：参考自：<https://discuss.pytorch.org/t/reset-the-parameters-of-a-model/29839>**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Init torch.Size([256, 20])\n",
      "Init torch.Size([10, 256])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([-5.5105,  8.5852,  0.0000, -5.6937,  0.0000,  0.0000, -5.8642,  6.8767,\n",
       "        -0.0000, -0.0000, -0.0000, -7.2237, -0.0000,  0.0000, -7.3165,  6.4280,\n",
       "        -0.0000,  8.8226, -0.0000, -0.0000])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# PyTorch不允许对requires_gead=True的tensor做inplace操作。\n",
    "# 所以需要取出weight的 data 属性，对它进行相应的处理\n",
    "def weight_init(m):\n",
    "    if isinstance(m, nn.Linear):\n",
    "        print('Init', m.weight.shape)\n",
    "        m.weight.data.uniform_(-10, to=10)\n",
    "        m.weight.data *= (m.weight.data.abs() >= 5).float()\n",
    "\n",
    "net.apply(weight_init)\n",
    "net[0].weight.data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**注：因为 PyTorch 并没有提供对模型所有参数进行初始化的方法，我们写一个自定义的全局初始化参数的方法加入到 d2ltorch 包中。(该方法来自 [Weights-Initializer-pytorch](https://github.com/3ammor/Weights-Initializer-pytorch) )**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def params_init(model, init, **kwargs): # 本方法已保存在d2ltorch包中方便以后使用\n",
    "\n",
    "    def initializer(m):\n",
    "        if isinstance(m, nn.Conv2d):\n",
    "            init(m.weight.data, **kwargs)\n",
    "            m.bias.data.fill_(0)\n",
    "\n",
    "        elif isinstance(m, nn.Linear):\n",
    "            init(m.weight.data, **kwargs)\n",
    "            m.bias.data.fill_(0)\n",
    "\n",
    "        elif isinstance(m, nn.BatchNorm2d):\n",
    "            m.weight.data.fill_(1.0)\n",
    "            m.bias.data.fill_(0)\n",
    "\n",
    "        elif isinstance(m, nn.BatchNorm1d):\n",
    "            m.weight.data.fill_(1.0)\n",
    "            m.bias.data.fill_(0)\n",
    "\n",
    "    model.apply(initializer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此外，我们还可以通过`Parameter`类的`data`属性来直接改写模型参数。例如，在下例中我们将隐藏层参数在现有的基础上加1。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "13"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-4.5105,  9.5852,  1.0000, -4.6937,  1.0000,  1.0000, -4.8642,  7.8767,\n",
       "         1.0000,  1.0000,  1.0000, -6.2237,  1.0000,  1.0000, -6.3165,  7.4280,\n",
       "         1.0000,  9.8226,  1.0000,  1.0000])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].weight.data = net[0].weight.data + 1\n",
    "net[0].weight.data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 共享模型参数\n",
    "\n",
    "在有些情况下，我们希望在多个层之间共享模型参数。[“模型构造”](model-construction.ipynb)一节介绍了如何在`Module`类的`forward`函数里多次调用同一个层来计算。\n",
    "\n",
    "**注：经过查找，除了上述方法外，PyTorch并没有提供其他的共享参数并且传播时保持梯度的方法。**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 有多种方法来访问、初始化和共享模型参数。\n",
    "* 可以自定义初始化方法。\n",
    "\n",
    "\n",
    "## 练习\n",
    "\n",
    "* 查阅有关`nn.init`模块的PyTorch文档，了解不同的参数初始化方法。\n",
    "* 尝试在`net`实例化后、`net(X)`前访问模型参数，观察模型参数的形状。\n",
    "* 构造一个含共享参数层的多层感知机并训练。在训练过程中，观察每一层的模型参数和梯度。\n",
    "\n",
    "\n",
    "\n",
    "## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/987)\n",
    "\n",
    "![](../img/qr_parameters.svg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
